import os


class ModelPlotSaver:
    def __init__(self,
                 plot_train_save_path: str,
                 plot_cm_save_path: str,
                 plot_pr_save_path: str,
                 **kwargs):
        """
        Initialize the model plot saver.

        Stores the three output paths and any extra keyword arguments as
        instance attributes, then makes sure the parent directory of each
        output path exists.

        :param plot_train_save_path: save path for the training-process plot
        :param plot_cm_save_path: save path for the training confusion-matrix plot
        :param plot_pr_save_path: save path for the training PR-curve plot
        :param kwargs: other options; each key/value pair is set as an
            instance attribute
        :raises OSError: if a parent directory cannot be created (for
            example, a regular file already exists at that location)
        """
        self.plot_train_save_path = plot_train_save_path
        self.plot_cm_save_path = plot_cm_save_path
        self.plot_pr_save_path = plot_pr_save_path
        self.__dict__.update(kwargs)

        # Create the parent directory of every output path.
        for save_path in (plot_train_save_path,
                          plot_cm_save_path,
                          plot_pr_save_path):
            save_dir = os.path.dirname(save_path)
            # A bare filename has no directory component (dirname == ''),
            # i.e. the current working directory, which always exists;
            # calling os.makedirs('') would raise FileNotFoundError.
            if save_dir:
                # exist_ok=True avoids the exists()/makedirs() race when
                # several processes create the same directory at once.
                os.makedirs(save_dir, exist_ok=True)